{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "class Model4_1(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_1, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 100) \n",
    "        self.relu = nn.ReLU()\n",
    "        self.lin2 = nn.Linear(100, 10)        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.relu(out)\n",
    "        out = self.lin2(out)\n",
    "        return out\n",
    "\n",
    "model4_1 = Model4_1() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model4_2(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_2, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 100) \n",
    "        self.tanh = nn.Tanh()\n",
    "        self.lin2 = nn.Linear(100, 10)        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.tanh(out)\n",
    "        out = self.lin2(out)\n",
    "        return out\n",
    "\n",
    "model4_2 = Model4_2() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "class Model4_3(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_3, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 100) \n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.lin2 = nn.Linear(100, 10)        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.sigmoid(out)\n",
    "        out = self.lin2(out)\n",
    "        return out\n",
    "\n",
    "model4_3 = Model4_3() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "batchSize = 100\n",
    "trainSet = torchvision.datasets.MNIST(root='./data', train = True, transform=transforms.ToTensor(), download=True)\n",
    "trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=batchSize, shuffle = True)\n",
    "testSet = torchvision.datasets.MNIST(root='./data', train = False, transform=transforms.ToTensor(), download=True)\n",
    "testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=batchSize, shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "import time\n",
    "\n",
    "def benchmark(trainLoader, model, epochs=1, lr=0.01):\n",
    "    model.__init__()\n",
    "    start=time.time()\n",
    "    optimiser = optim.SGD(model.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        for i, (images, labels) in enumerate(trainLoader):\n",
    "            optimiser.zero_grad()\n",
    "            outputs = model(images.view(-1, 28*28))\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "    print('Accuracy: {0:.4f}'.format(accuracy(testLoader,model)))\n",
    "    print('Training time: {0:.2f}'.format(time.time() - start))       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def accuracy(testLoader,model):    \n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data in testLoader:\n",
    "            images, labels = data\n",
    "            outputs = model(images.view(-1, 28*28))\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()      \n",
    "    return(correct / total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ReLU actiavtion:\n",
      "Accuracy: 0.9581\n",
      "Training time: 29.28\n",
      "Tanh activation\n",
      "Accuracy: 0.9507\n",
      "Training time: 30.02\n",
      "sigmoid activation\n",
      "Accuracy: 0.9182\n",
      "Training time: 29.49\n"
     ]
    }
   ],
   "source": [
    "print('ReLU activation:')\n",
    "benchmark(trainLoader, model4_1, epochs=5, lr = 0.1)\n",
    "print('Tanh activation')\n",
    "benchmark(trainLoader, model4_2, epochs=5, lr = 0.1)\n",
    "print('sigmoid activation')\n",
    "benchmark(trainLoader, model4_3, epochs=5, lr = 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model4_4(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_4, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 100) \n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.lin2 = nn.Linear(100, 50)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        self.lin3 = nn.Linear(50, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.relu1(out)\n",
    "        out = self.lin2(out)\n",
    "        out = self.relu2(out)\n",
    "        out = self.lin3(out)\n",
    "        \n",
    "        return out\n",
    "\n",
    "model4_4 = Model4_4() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model4_4b(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_4b, self).__init__()\n",
    "        self.layer1=nn.Sequential(nn.Linear(784, 100), \n",
    "            nn.ReLU())\n",
    "        self.layer2=nn.Sequential(nn.Linear(100, 50),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(50, 10))\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)                          \n",
    "        return out\n",
    "\n",
    "model4_4b = Model4_4b() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9118\n",
      "Training time: 6.88\n"
     ]
    }
   ],
   "source": [
    "benchmark(trainLoader, model4_4b, epochs=1, lr = 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model4_6(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_6, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 500) \n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.lin2 = nn.Linear(500,300)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        self.lin3 = nn.Linear(300, 100)\n",
    "        self.relu3 = nn.ReLU()\n",
    "        self.lin4 = nn.Linear(100, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.relu1(out)\n",
    "        out = self.lin2(out)\n",
    "        out = self.relu2(out)\n",
    "        out = self.lin3(out)\n",
    "        out = self.relu3(out)\n",
    "        out = self.lin4(out)\n",
    "        \n",
    "        return out\n",
    "\n",
    "model4_6 = Model4_6() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model4_5(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_5, self).__init__()\n",
    "        self.lin1 = nn.Linear(784, 300) \n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.lin2 = nn.Linear(300, 100)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        self.lin3 = nn.Linear(100, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.lin1(x)\n",
    "        out = self.relu1(out)\n",
    "        out = self.lin2(out)\n",
    "        out = self.relu2(out)\n",
    "        out = self.lin3(out)\n",
    "        \n",
    "        return out\n",
    "\n",
    "model4_5 = Model4_5() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 Linear layers 100 - 50 - 10:\n",
      "Accuracy: 0.9668\n",
      "Training time: 31.13\n",
      "3 Linear layers 300 -100 - 10:\n",
      "Accuracy: 0.9694\n",
      "Training time: 37.79\n",
      "4 Linear layers 500 - 300 -100 - 10 :\n",
      "Accuracy: 0.9727\n",
      "Training time: 49.87\n"
     ]
    }
   ],
   "source": [
    "print('3 Linear layers 100 - 50 - 10:')\n",
    "benchmark(trainLoader, model4_4, epochs=5, lr = 0.1)\n",
    "print('3 Linear layers 300 -100 - 10:')\n",
    "benchmark(trainLoader, model4_5, epochs=5, lr = 0.1)\n",
    "print('4 Linear layers 500 - 300 -100 - 10 :')\n",
    "benchmark(trainLoader, model4_6, epochs=5, lr = 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
